{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.13571768, 0.9907476 , 0.80379736], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Define the environment: Pendulum-v1 behind the old 4-tuple gym API.\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    \"\"\"Pendulum-v1 wrapper exposing the pre-0.26 gym reset/step API.\n",
    "\n",
    "    reset() returns only the observation and step() returns\n",
    "    (state, reward, done, info), where done is True once the wrapped env\n",
    "    reports terminated or truncated, or after max_steps steps.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_steps=200):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        # Episode length cap enforced in step(); 200 matches the original.\n",
    "        self.max_steps = max_steps\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self, *, seed=None, options=None):\n",
    "        \"\"\"Start a new episode and return only the initial observation.\n",
    "\n",
    "        seed and options are forwarded to the wrapped env, so episodes can\n",
    "        be made reproducible with env.reset(seed=...).\n",
    "        \"\"\"\n",
    "        state, _ = self.env.reset(seed=seed, options=options)\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Advance one step and return (state, reward, done, info).\"\"\"\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        self.step_n += 1\n",
    "        # Truncation (time limit) is folded into done, so callers cannot tell\n",
    "        # it apart from a genuine terminal state.\n",
    "        done = terminated or truncated or self.step_n >= self.max_steps\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiyElEQVR4nO3de3BU52H38d+u9iJ02dUFtGsFKdDYNVYxNAaMt+lMOkFFSTVJHPOH46GEukw8JsJjTMYzprXxJG9mxDjv2yRubdI2re2ZNqZDJjg1xUlUYctJLAOWwREXq06LLRXYFSC0Kwm0K+0+7x+LTliQHYmL9pH4fmbOaHXOs9KzJ2S/PmePdl3GGCMAACzkzvcEAAD4MEQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGCtvEXqmWee0bx581RYWKjly5dr3759+ZoKAMBSeYnUv/3bv2nTpk168skn9fbbb2vx4sVqaGhQb29vPqYDALCUKx9vMLt8+XItW7ZMf/d3fydJymQyqqmp0UMPPaTHHntsqqcDALCUZ6p/YSqVUkdHhzZv3uysc7vdqq+vV3t7+7j3SSaTSiaTzveZTEZ9fX2qrKyUy+W67nMGAFxbxhgNDAyourpabveHn9Sb8kidPn1a6XRaoVAoZ30oFNK777477n2am5v1jW98YyqmBwCYQj09PZo7d+6Hbp/ySF2JzZs3a9OmTc738XhctbW16unpUSAQyOPMAABXIpFIqKamRqWlpR85bsojNXv2bBUUFCgWi+Wsj8ViCofD497H7/fL7/dftj4QCBApAJjGftdLNlN+dZ/P59OSJUvU2trqrMtkMmptbVUkEpnq6QAALJaX032bNm3S2rVrtXTpUt1555367ne/q6GhId1///35mA4AwFJ5idS9996rU6dOacuWLYpGo/rDP/xD/fSnP73sYgoAwI0tL38ndbUSiYSCwaDi8TivSQHANDTR53Heuw8AYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtSYdqddff12f//znVV1dLZfLpZdeeilnuzFGW7Zs0U033aRZs2apvr5e7733Xs6Yvr4+rV69WoFAQGVlZVq3bp0GBwev6oEAAGaeSUdqaGhIixcv1jPPPDPu9qeeekpPP/20vv/972vv3r0qLi5WQ0ODhoeHnTGrV6/W4cOH1dLSol27dun111/XAw88cOWPAgAwM5mrIMns3LnT+T6TyZhwOGy+/e1vO+v6+/uN3+83L774ojHGmCNHjhhJZv/+/c6YV155xbhcLnP8+PEJ/d54PG4kmXg8fjXTBwDkyUSfx6/pa1LHjh1TNBpVfX29sy4YDGr58uVqb2+XJLW3t6usrExLly51xtTX18vtdmvv3r3j/txkMqlEIpGzAABmvmsaqWg0KkkKhUI560OhkLMtGo2qqqoqZ7vH41FFRYUz5lLNzc0KBoPOUlNTcy2nDQCw1LS4um/z5s2Kx+PO0tPTk+8pAQCmwDWNVDgcliTFYrGc9bFYzNkWDofV29ubs310dFR9fX3OmEv5/X4FAoGcBQAw813TSM2fP1/hcFitra3OukQiob17
9yoSiUiSIpGI+vv71dHR4YzZs2ePMpmMli9ffi2nAwCY5jyTvcPg4KB+85vfON8fO3ZMBw8eVEVFhWpra7Vx40Z961vf0i233KL58+friSeeUHV1te6++25J0m233abPfvaz+upXv6rvf//7GhkZ0YYNG/TlL39Z1dXV1+yBAQBmgMleNvjqq68aSZcta9euNcZkL0N/4oknTCgUMn6/36xYscJ0dXXl/IwzZ86Y++67z5SUlJhAIGDuv/9+MzAwcM0vXQQA2Gmiz+MuY4zJYyOvSCKRUDAYVDwe5/UpAJiGJvo8Pi2u7gMA3JiIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1ppUpJqbm7Vs2TKVlpaqqqpKd999t7q6unLGDA8Pq6mpSZWVlSopKdGqVasUi8VyxnR3d6uxsVFFRUWqqqrSo48+qtHR0at/NACAGWVSkWpra1NTU5PefPNNtbS0aGRkRCtXrtTQ0JAz5pFHHtHLL7+sHTt2qK2tTSdOnNA999zjbE+n02psbFQqldIbb7yhF154Qc8//7y2bNly7R4VAGBmMFeht7fXSDJtbW3GGGP6+/uN1+s1O3bscMYcPXrUSDLt7e3GGGN2795t3G63iUajzpht27aZQCBgksnkhH5vPB43kkw8Hr+a6QMA8mSiz+NX9ZpUPB6XJFVUVEiSOjo6NDIyovr6emfMggULVFtbq/b2dklSe3u7br/9doVCIWdMQ0ODEomEDh8+PO7vSSaTSiQSOQsAYOa74khlMhlt3LhRn/rUp7Rw4UJJUjQalc/nU1lZWc7YUCikaDTqjLk4UGPbx7aNp7m5WcFg0FlqamqudNoAgGnkiiPV1NSkQ4cOafv27ddyPuPavHmz4vG4s/T09Fz33wkAyD/Pldxpw4YN2rVrl15//XXNnTvXWR8Oh5VKpdTf359zNBWLxRQOh50x+/bty/l5Y1f/jY25lN/vl9/vv5KpAgCmsUkdSRljtGHDBu3cuVN79uzR/Pnzc7YvWbJEXq9Xra2tzrquri51d3crEolIkiKRiDo7O9Xb2+uMaWlpUSAQUF1d3dU8FgDADDOpI6mmpib98Ic/1E9+8hOVlpY6ryEFg0HNmjVLwWBQ69at06ZNm1RRUaFAIKCHHnpIkUhEd911lyRp5cqVqqur05o1a/TUU08pGo3q8ccfV1NTE0dLAIAcLmOMmfBgl2vc9c8995z+4i/+QlL2j3m//vWv68UXX1QymVRDQ4OeffbZnFN5H3zwgdavX6/XXntNxcXFWrt2rbZu3SqPZ2LNTCQSCgaDisfjCgQCE50+AMASE30en1SkbEGkAGB6m+jzOO/dBwCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYK0reoNZYIwxRspksl8v+rtwl8slXbzow9+xBAA+DJHCFTHGaPTsWQ2++67629s12NWl0XhcymTk8vvlD4VUfOut8ofD8lVVyRsMqqCoSG6/X+7CwuxXv19yuYgXgA9FpHBFBo8eVexHP9JAZ6cyyWTuxmRS5xIJnXvvPWeVq6BA3tmz5QkG5S0vlzcYdG57ysrkLSvL3g4GVVBcTLgASCJSmCSTyWigs1P/+0//pPPvvz/x
+6XTSsViSl347LAxLp/POaq6eCmsqZG/qip7JBYOy19VJXdhoVxud/bo68JXjsSAmY1IYcKMMRp67z31/OM/ari7+9r8zFRK6VRK6YGBnPVDR49e9nqWp7xcvqqq38Zrzhx5y8tVUFqqgsJCuWfNUsGsWXLPmiVXQQHxAmYAIoUJMyMjOvEv/3LNAvW7f+FvL8YwkkZOn9bI6dMaOnIkZ5invFye0tLs6cNgMHu7vFzeiorsqcWKiuwSCMhVUDA1cwdwTRApTNjQb36jgc7OfE/jMqNnz2r07NnclQUFcvt8cnu92VOKPp9cXq/8N90kfzgsfyjk3PaUlmbj5XZnj8Dc7uxtjsSAvCNSmLDjzz0nZTL5nsbEpNPKnD+vzPnzOauHP/ggd5zLpYJZs+StqpJvzhz558zJXo1YUSFPIKCC4mIVFBU5X91+P/ECphCRwo3NGKXPnVP6/fc1fMmFIO6iInkCgezpw9LS7O1AQN7KSvkqK7OnECsr5auoyF5OD+CaI1LAh8icO6fUuXNKRaO/XelyyeXxyOX1yn3hq8vjkbe8XP6PfSx7GjEclr+6Wr7ycrl8vuz4goLfLm7e6AWYKCIFTIYxMiMjMiMjuvjEZyoW09C77+YMdXm98l64ItFbWSnf7NnyzZnjHJ0VlJSooLhYnpISFVy4IhFALiIFXCdmZESp3l6lentz1rt8vmyYxiI1djqxoiL7mtjs2dmozZmjguLi396P18JwAyJSmLDqP/9zvffkk9Pn4glLmVRKI319Gunry90wdjpw7PSgxyNvMKjiW29V+R//sYo+8QnejQM3HCKFCSu57TaVLFigwUv+TgnXSDotk07LpFLOqtGzZ3X+/fd15tVXFbzjDlXdfbdKbruNUOGGwSu4mDCX16uPfeUr8n/sY/meyg3HpFLqf/NNdW/bpsFDh7LvOg/cAIgUJszlcql4wQLNXbeOUOXJ8AcfqPsf/kGDR48SKtwQiBQmxeV2q2zpUtV+7WsqXbxYLp8v31OacY4PDWlXT49e/J//0X+eOKGhkZGc7cMffKCTL76o9OBgnmYITB1ek8IVKV24UIXhsAYOHdLZX/1K8bfeym4Y+697/it/0owxOjY4qCcPHND7g4MaTqcV8Hq1sLxc/3fZMnkv+vuqgXfe0bn//u/sfyjw+hRmMJeZhucMEomEgsGg4vG4AoFAvqdzQzOZjMzoqNLnzikZiyl54kR2uXA7PTysTDKZXS7cJmDj+++BAT3wq18pfsmRkyTdOXu2/s8nP6nKwkJnnb+6Wn+wbRuRwrQ00edxjqRwVVxut/MGrt6yMpXcequzzaTTGonHNXr2rFJnzmjk7Nnspdf9/Rrt79dIf3/2+7Nnc65ou1F99/DhcQMlSftOn1bLiRP68u/9nrPOfMhYYCYhUrhuXAUF8lVUyFdRoaJPfEJS9pSWSaWyR1jDw8qcP6/08LBGzp5V8uRJJaNRJU+e1PDx49mPo7/wcR1m7GM7OAoDbihEClPK5XLJdeHTdxUMOutzInQhSunz55WMRpWKRpWMxTR8/LhSsZgyyaTSF97hPD3OO50DmDmIFKzguuhTeCXJJcnt9cobCEi///vO+szoqNKJhFJnz2o0Htdof79Sp05pdGBAI319zmnFVG/vtHtnjMaaGr11+rRGxjlanFdSokUVFXmYFZBfRArTitvjkfvCJ+2OMRfe9NW5QOPCkdbI2bPZ04cXXcwx0t+ffWeHTEYmk8mGzJJTiA3V1ZKkb73zjlLptDKSClwulfl8+n/LlunjJSU540OrVuVhlsDUIlKY9lwul3PxhkpLJSn3D10vuix+dGjIedPX4ZMnlYrFNHLmjEYHBpQ+fz772VJDQ8oMD0/5kZjL5VJDdbXmFhVp1//+r84MD2teSYnunT9flZd8XpX/pptUHolwZR9mPCKFGSnnyfui295AQN5AQMU33+ysy4yOajSRyC7xePaKxHg8e/rw1Kns19OnNdLXd92v
QnS5XFpYXq6F5eUfOsZTXq7qNWvk4c8vcAMgUrjhuT0e5yrEMc4pxFQq+/XC6cSRvj4NHz+ePX0YjWr4xInsVYhjbw6bycik09ftKKygpEQ33Xuvyu68k8+fwg2BSAHjyDmFeIExRoVz56p00SJdtFLpoSGlTp1SMhZzTiWmTp9WemhIo4ODSg8OKj00pPS5c1f8+peroECFNTUKfelLqviTP+E0H24YRAqYoHHD4HI5H1pYdOEPbY0xMqOj2de5BgaypxEHBzUajyt16pRSZ85kv546pZHTp2VGRy//uQUFUiajgpISFc6dq+DSpQouXapZtbUECjcUIgVcYy6XSy6vN3v68NJTiKOjOUvmwqf3Jk+c0HA06pxK/NiaNZo1b55cBQVye71yFxXJ7eH/rrjx8K8emCJj8ZLX66wzxsg3Z45K/uAPxh0P3OiIFJBHhAj4aHyeFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrTSpS27Zt06JFixQIBBQIBBSJRPTKK68424eHh9XU1KTKykqVlJRo1apVisViOT+ju7tbjY2NKioqUlVVlR599FGNjo5em0cDAJhRJhWpuXPnauvWrero6NBbb72lz3zmM/riF7+ow4cPS5IeeeQRvfzyy9qxY4fa2tp04sQJ3XPPPc790+m0GhsblUql9MYbb+iFF17Q888/ry1btlzbRwUAmBnMVSovLzc/+MEPTH9/v/F6vWbHjh3OtqNHjxpJpr293RhjzO7du43b7TbRaNQZs23bNhMIBEwymZzw74zH40aSicfjVzt9AEAeTPR5/Ipfk0qn09q+fbuGhoYUiUTU0dGhkZER1dfXO2MWLFig2tpatbe3S5La29t1++23KxQKOWMaGhqUSCSco7HxJJNJJRKJnAUAMPNNOlKdnZ0qKSmR3+/Xgw8+qJ07d6qurk7RaFQ+n09lZWU540OhkKLRqCQpGo3mBGps+9i2D9Pc3KxgMOgsNTU1k502AGAamnSkbr31Vh08eFB79+7V+vXrtXbtWh05cuR6zM2xefNmxeNxZ+np6bmuvw8AYAfPZO/g8/l08803S5KWLFmi/fv363vf+57uvfdepVIp9ff35xxNxWIxhcNhSVI4HNa+fftyft7Y1X9jY8bj9/vl9/snO1UAwDR31X8nlclklEwmtWTJEnm9XrW2tjrburq61N3drUgkIkmKRCLq7OxUb2+vM6alpUWBQEB1dXVXOxUAwAwzqSOpzZs363Of+5xqa2s1MDCgH/7wh3rttdf0s5/9TMFgUOvWrdOmTZtUUVGhQCCghx56SJFIRHfddZckaeXKlaqrq9OaNWv01FNPKRqN6vHHH1dTUxNHSgCAy0wqUr29vfrKV76ikydPKhgMatGiRfrZz36mP/3TP5Ukfec735Hb7daqVauUTCbV0NCgZ5991rl/QUGBdu3apfXr1ysSiai4uFhr167VN7/5zWv7qAAAM4LLGGPyPYnJSiQSCgaDisfjCgQC+Z4OAGCSJvo8znv3AQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItI
AQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALDWVUVq69atcrlc2rhxo7NueHhYTU1NqqysVElJiVatWqVYLJZzv+7ubjU2NqqoqEhVVVV69NFHNTo6ejVTAQDMQFccqf379+vv//7vtWjRopz1jzzyiF5++WXt2LFDbW1tOnHihO655x5nezqdVmNjo1KplN544w298MILev7557Vly5YrfxQAgJnJXIGBgQFzyy23mJaWFvPpT3/aPPzww8YYY/r7+43X6zU7duxwxh49etRIMu3t7cYYY3bv3m3cbreJRqPOmG3btplAIGCSyeSEfn88HjeSTDwev5LpAwDybKLP41d0JNXU1KTGxkbV19fnrO/o6NDIyEjO+gULFqi2tlbt7e2SpPb2dt1+++0KhULOmIaGBiUSCR0+fHjc35dMJpVIJHIWAMDM55nsHbZv3663335b+/fvv2xbNBqVz+dTWVlZzvpQKKRoNOqMuThQY9vHto2nublZ3/jGNyY7VQDANDepI6menh49/PDD+td//VcVFhZerzldZvPmzYrH487S09MzZb8bAJA/k4pUR0eHent7dccdd8jj8cjj8aitrU1PP/20PB6PQqGQUqmU+vv7c+4Xi8UUDoclSeFw+LKr/ca+HxtzKb/fr0AgkLMAAGa+SUVqxYoV6uzs1MGDB51l6dKlWr16tXPb6/WqtbXVuU9XV5e6u7sViUQkSZFIRJ2dnert7XXGtLS0KBAIqK6u7ho9LADATDCp16RKS0u1cOHCnHXFxcWqrKx01q9bt06bNm1SRUWFAoGAHnroIUUiEd11112SpJUrV6qurk5r1qzRU089pWg0qscff1xNTU3y+/3X6GEBAGaCSV848bt85zvfkdvt1qpVq5RMJtXQ0KBnn33W2V5QUKBdu3Zp/fr1ikQiKi4u1tq1a/XNb37zWk8FADDNuYwxJt+TmKxEIqFgMKh4PM7rUwAwDU30eZz37gMAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWMuT7wlcCWOMJCmRSOR5JgCAKzH2/D32fP5hpmWkzpw5I0mqqanJ80wAAFdjYGBAwWDwQ7dPy0hVVFRIkrq7uz/ywd3oEomEampq1NPTo0AgkO/pWIv9NDHsp4lhP02MMUYDAwOqrq7+yHHTMlJud/altGAwyD+CCQgEAuynCWA/TQz7aWLYT7/bRA4yuHACAGAtIgUAsNa0jJTf79eTTz4pv9+f76lYjf00MeyniWE/TQz76dpymd91/R8AAHkyLY+kAAA3BiIFALAWkQIAWItIAQCsNS0j9cwzz2jevHkqLCzU8uXLtW/fvnxPaUq9/vrr+vznP6/q6mq5XC699NJLOduNMdqyZYtuuukmzZo1S/X19XrvvfdyxvT19Wn16tUKBAIqKyvTunXrNDg4OIWP4vpqbm7WsmXLVFpaqqqqKt19993q6urKGTM8PKympiZVVlaqpKREq1atUiwWyxnT3d2txsZGFRUVqaqqSo8++qhGR0en8qFcV9u2bdOiRYucPzyNRCJ65ZVXnO3so/Ft3bpVLpdLGzdudNaxr64TM81s377d+Hw+88///M/m8OHD5qtf/aop
KyszsVgs31ObMrt37zZ//dd/bX784x8bSWbnzp0527du3WqCwaB56aWXzDvvvGO+8IUvmPnz55vz5887Yz772c+axYsXmzfffNP84he/MDfffLO57777pviRXD8NDQ3mueeeM4cOHTIHDx40f/Znf2Zqa2vN4OCgM+bBBx80NTU1prW11bz11lvmrrvuMn/0R3/kbB8dHTULFy409fX15sCBA2b37t1m9uzZZvPmzfl4SNfFv//7v5v/+I//MP/1X/9lurq6zF/91V8Zr9drDh06ZIxhH41n3759Zt68eWbRokXm4Ycfdtazr66PaRepO++80zQ1NTnfp9NpU11dbZqbm/M4q/y5NFKZTMaEw2Hz7W9/21nX399v/H6/efHFF40xxhw5csRIMvv373fGvPLKK8blcpnjx49P2dynUm9vr5Fk2trajDHZfeL1es2OHTucMUePHjWSTHt7uzEm+x8DbrfbRKNRZ8y2bdtMIBAwyWRyah/AFCovLzc/+MEP2EfjGBgYMLfccotpaWkxn/70p51Isa+un2l1ui+VSqmjo0P19fXOOrfbrfr6erW3t+dxZvY4duyYotFozj4KBoNavny5s4/a29tVVlampUuXOmPq6+vldru1d+/eKZ/zVIjH45J+++bEHR0dGhkZydlPCxYsUG1tbc5+uv322xUKhZwxDQ0NSiQSOnz48BTOfmqk02lt375dQ0NDikQi7KNxNDU1qbGxMWefSPx7up6m1RvMnj59Wul0Oud/ZEkKhUJ699138zQru0SjUUkadx+NbYtGo6qqqsrZ7vF4VFFR4YyZSTKZjDZu3KhPfepTWrhwoaTsPvD5fCorK8sZe+l+Gm8/jm2bKTo7OxWJRDQ8PKySkhLt3LlTdXV1OnjwIPvoItu3b9fbb7+t/fv3X7aNf0/Xz7SKFHAlmpqadOjQIf3yl7/M91SsdOutt+rgwYOKx+P60Y9+pLVr16qtrS3f07JKT0+PHn74YbW0tKiwsDDf07mhTKvTfbNnz1ZBQcFlV8zEYjGFw+E8zcouY/vho/ZROBxWb29vzvbR0VH19fXNuP24YcMG7dq1S6+++qrmzp3rrA+Hw0qlUurv788Zf+l+Gm8/jm2bKXw+n26++WYtWbJEzc3NWrx4sb73ve+xjy7S0dGh3t5e3XHHHfJ4PPJ4PGpra9PTTz8tj8ejUCjEvrpOplWkfD6flixZotbWVmddJpNRa2urIpFIHmdmj/nz5yscDufso0Qiob179zr7KBKJqL+/Xx0dHc6YPXv2KJPJaPny5VM+5+vBGKMNGzZo586d2rNnj+bPn5+zfcmSJfJ6vTn7qaurS93d3Tn7qbOzMyfoLS0tCgQCqqurm5oHkgeZTEbJZJJ9dJEVK1aos7NTBw8edJalS5dq9erVzm321XWS7ys3Jmv79u3G7/eb559/3hw5csQ88MADpqysLOeKmZluYGDAHDhwwBw4cMBIMn/zN39jDhw4YD744ANjTPYS9LKyMvOTn/zE/PrXvzZf/OIXx70E/ZOf/KTZu3ev+eUvf2luueWWGXUJ+vr1600wGDSvvfaaOXnypLOcO3fOGfPggw+a2tpas2fPHvPWW2+ZSCRiIpGIs33skuGVK1eagwcPmp/+9Kdmzpw5M+qS4ccee8y0tbWZY8eOmV//+tfmscceMy6Xy/z85z83xrCPPsrFV/cZw766XqZdpIwx5m//9m9NbW2t8fl85s477zRvvvlmvqc0pV599VUj6bJl7dq1xpjsZehPPPGECYVCxu/3mxUrVpiurq6cn3HmzBlz3333mZKSEhMIBMz9999vBgYG8vBoro/x9o8k89xzzzljzp8/b772ta+Z8vJyU1RUZL70pS+ZkydP5vyc999/33zuc58zs2bNMrNnzzZf//rXzcjIyBQ/muvnL//yL83HP/5x4/P5zJw5c8yKFSucQBnDPvool0aKfXV98FEdAABrTavXpAAANxYiBQCwFpECAFiL
SAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArPX/ASQB1QkkPZB4AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# Display the current frame of the environment.\n",
    "def show():\n",
    "    \"\"\"Render env as an RGB array and show it with matplotlib.\"\"\"\n",
    "    frame = env.render()\n",
    "    plt.imshow(frame)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "这个游戏的状态用3个数字表示,我也不知道这3个数字分别是什么意思,反正这3个数字就能描述游戏全部的状态\n",
      "state= [-0.99827415  0.0587259  -0.65752804]\n",
      "这个游戏的动作是个-2到+2之间的连续值\n",
      "env.action_space= Box(-2.0, 2.0, (1,), float32)\n",
      "随机一个动作\n",
      "action= [1.9444304]\n",
      "执行一个动作,得到下一个状态,奖励,是否结束\n",
      "state= [-0.9972      0.07478078 -0.32181907]\n",
      "reward= -9.55087409733753\n",
      "over= False\n"
     ]
    }
   ],
   "source": [
    "# Smoke-test the environment: reset, sample a random action, take one step.\n",
    "def test_env():\n",
    "    \"\"\"Print the initial state, the action space, a random action and one step's result.\"\"\"\n",
    "    initial_state = env.reset()\n",
    "    print('这个游戏的状态用3个数字表示,我也不知道这3个数字分别是什么意思,反正这3个数字就能描述游戏全部的状态')\n",
    "    # Per the Pendulum-v1 docs the 3 values are cos(theta), sin(theta) and angular velocity.\n",
    "    print('state=', initial_state)\n",
    "\n",
    "    print('这个游戏的动作是个-2到+2之间的连续值')\n",
    "    print('env.action_space=', env.action_space)\n",
    "\n",
    "    print('随机一个动作')\n",
    "    random_action = env.action_space.sample()\n",
    "    print('action=', random_action)\n",
    "\n",
    "    print('执行一个动作,得到下一个状态,奖励,是否结束')\n",
    "    next_state, step_reward, is_over, _info = env.step(random_action)\n",
    "    for label, value in (('state=', next_state), ('reward=', step_reward), ('over=', is_over)):\n",
    "        print(label, value)\n",
    "\n",
    "\n",
    "test_env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Sequential(\n",
       "   (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=11, bias=True)\n",
       " ),\n",
       " Sequential(\n",
       "   (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=11, bias=True)\n",
       " ))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Q-network factory: one shared definition keeps both networks identical.\n",
    "def build_q_network(state_dim=3, hidden_dim=128, n_actions=11):\n",
    "    \"\"\"Return a fresh MLP mapping a state_dim state to n_actions Q-values.\"\"\"\n",
    "    return torch.nn.Sequential(\n",
    "        torch.nn.Linear(state_dim, hidden_dim),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Linear(hidden_dim, n_actions),\n",
    "    )\n",
    "\n",
    "\n",
    "# Model that computes actions; this is the model actually used.\n",
    "model = build_q_network()\n",
    "\n",
    "# Second network with the same architecture, used to evaluate a state's score.\n",
    "next_model = build_q_network()\n",
    "\n",
    "# Copy model's parameters into next_model so both start identical.\n",
    "next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "model, next_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(7, 0.7999999999999998)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def get_action(state, epsilon=0.01):\n",
    "    \"\"\"Choose a discrete action epsilon-greedily and map it to a continuous torque.\n",
    "\n",
    "    Args:\n",
    "        state: one observation with 3 values.\n",
    "        epsilon: probability of replacing the greedy action with a uniformly\n",
    "            random one (0.01 by default, as before).\n",
    "\n",
    "    Returns:\n",
    "        (action, action_continuous): the discrete index in [0, n_actions)\n",
    "        and its linear mapping onto [-2, 2], where n_actions is the size of\n",
    "        model's output layer.\n",
    "    \"\"\"\n",
    "    state = torch.FloatTensor(state).reshape(1, 3)\n",
    "    # Action selection is inference only: skip building an autograd graph.\n",
    "    with torch.no_grad():\n",
    "        q_values = model(state)\n",
    "    n_actions = q_values.shape[-1]\n",
    "    action = q_values.argmax().item()\n",
    "\n",
    "    # Occasionally explore with a uniformly random action.\n",
    "    if random.random() < epsilon:\n",
    "        action = random.choice(range(n_actions))\n",
    "\n",
    "    # Map index 0..n_actions-1 linearly onto [-2, 2].\n",
    "    action_continuous = action / (n_actions - 1) * 4 - 2\n",
    "\n",
    "    return action, action_continuous\n",
    "\n",
    "\n",
    "get_action([0.29292667, 0.9561349, 1.0957013])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((200, 0), 200)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Replay pool of (state, action, reward, next_state, over) transitions.\n",
    "datas = []\n",
    "\n",
    "\n",
    "def update_data(n_new=200, max_size=5000):\n",
    "    \"\"\"Play whole episodes until at least n_new transitions were added, then cap the pool.\n",
    "\n",
    "    Because episodes are played to the end, slightly more than n_new\n",
    "    transitions may be added. If the pool then holds more than max_size\n",
    "    items, the oldest ones are removed in place, so datas stays the same\n",
    "    list object.\n",
    "\n",
    "    Returns:\n",
    "        (update_count, drop_count): number of transitions added and dropped.\n",
    "    \"\"\"\n",
    "    old_count = len(datas)\n",
    "\n",
    "    while len(datas) - old_count < n_new:\n",
    "        state = env.reset()\n",
    "\n",
    "        # Play one episode until the environment reports it is over.\n",
    "        over = False\n",
    "        while not over:\n",
    "            action, action_continuous = get_action(state)\n",
    "            next_state, reward, over, _ = env.step([action_continuous])\n",
    "            # NOTE(review): over is also True when the episode hits the step\n",
    "            # limit, so time-limit truncation is stored as a terminal step.\n",
    "            datas.append((state, action, reward, next_state, over))\n",
    "            state = next_state\n",
    "\n",
    "    update_count = len(datas) - old_count\n",
    "    drop_count = max(len(datas) - max_size, 0)\n",
    "\n",
    "    # Drop the oldest items with one slice deletion; repeated pop(0) is O(n) per call.\n",
    "    del datas[:drop_count]\n",
    "\n",
    "    return update_count, drop_count\n",
    "\n",
    "\n",
    "update_data(), len(datas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1150/1416897299.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[-6.3054e-01, -7.7616e-01, -5.4079e+00],\n",
       "         [ 6.3669e-01, -7.7112e-01, -5.0849e+00],\n",
       "         [ 9.8729e-01, -1.5890e-01, -1.3218e+00],\n",
       "         [-6.2421e-01, -7.8126e-01,  6.1030e+00],\n",
       "         [ 8.4212e-01,  5.3929e-01,  2.3585e-01],\n",
       "         [ 9.0119e-01,  4.3342e-01, -4.0446e+00],\n",
       "         [-1.8082e-01, -9.8352e-01, -3.7430e+00],\n",
       "         [-9.2372e-01, -3.8307e-01, -6.6561e+00],\n",
       "         [ 3.2404e-01, -9.4604e-01,  1.5120e+00],\n",
       "         [-9.9148e-01,  1.3026e-01, -8.0000e+00],\n",
       "         [ 1.5388e-01, -9.8809e-01,  2.8590e+00],\n",
       "         [ 8.2102e-01,  5.7090e-01,  7.6032e-01],\n",
       "         [ 9.5944e-01,  2.8192e-01, -9.1941e-01],\n",
       "         [ 8.4711e-01,  5.3142e-01, -4.1749e+00],\n",
       "         [ 7.4245e-01, -6.6990e-01, -3.1891e+00],\n",
       "         [ 7.9007e-01,  6.1302e-01, -4.3961e+00],\n",
       "         [ 7.1286e-02, -9.9746e-01, -6.5938e+00],\n",
       "         [ 4.2832e-01,  9.0363e-01, -4.0346e+00],\n",
       "         [ 7.9577e-01,  6.0559e-01, -4.3788e+00],\n",
       "         [ 7.3450e-01,  6.7861e-01, -1.9576e+00],\n",
       "         [-8.9955e-01, -4.3682e-01, -8.0000e+00],\n",
       "         [-2.4498e-01, -9.6953e-01,  4.5460e+00],\n",
       "         [ 8.3711e-01,  5.4704e-01, -7.2971e-01],\n",
       "         [-9.9865e-01, -5.2034e-02, -8.0000e+00],\n",
       "         [-2.7554e-01, -9.6129e-01, -4.4352e+00],\n",
       "         [ 5.8378e-01,  8.1191e-01,  2.2394e+00],\n",
       "         [-9.4008e-01,  3.4096e-01, -8.0000e+00],\n",
       "         [ 2.6858e-01,  9.6326e-01, -4.2918e+00],\n",
       "         [ 5.7299e-02,  9.9836e-01, -4.9806e+00],\n",
       "         [ 9.9996e-01, -8.7358e-03, -3.7355e+00],\n",
       "         [ 5.6353e-02, -9.9841e-01, -6.6336e+00],\n",
       "         [ 2.4697e-01, -9.6902e-01, -1.1645e+00],\n",
       "         [-7.8222e-01,  6.2300e-01, -7.0370e+00],\n",
       "         [ 5.7525e-03,  9.9998e-01, -5.3929e+00],\n",
       "         [ 7.2670e-01,  6.8696e-01,  9.2797e-01],\n",
       "         [-3.9335e-01, -9.1939e-01,  5.3970e+00],\n",
       "         [ 1.7883e-02,  9.9984e-01, -6.5013e+00],\n",
       "         [ 9.6515e-01,  2.6171e-01, -3.8453e+00],\n",
       "         [-4.5717e-01, -8.8938e-01,  5.3331e+00],\n",
       "         [-2.6084e-01,  9.6538e-01, -6.0570e+00],\n",
       "         [ 6.2577e-01,  7.8001e-01,  2.5369e+00],\n",
       "         [ 5.3651e-01,  8.4389e-01, -5.1568e+00],\n",
       "         [ 1.4843e-01, -9.8892e-01, -2.0113e+00],\n",
       "         [ 9.7028e-01,  2.4197e-01, -8.2797e-01],\n",
       "         [ 9.0296e-01,  4.2972e-01, -1.4069e+00],\n",
       "         [ 8.2684e-01,  5.6244e-01, -1.9579e+00],\n",
       "         [ 4.5730e-01,  8.8931e-01,  2.9683e+00],\n",
       "         [-4.7264e-01, -8.8126e-01, -7.8168e+00],\n",
       "         [ 4.9168e-01,  8.7077e-01,  3.2419e+00],\n",
       "         [-9.4573e-01,  3.2494e-01, -8.0000e+00],\n",
       "         [ 2.9476e-01, -9.5557e-01,  9.7781e-01],\n",
       "         [-8.8938e-01,  4.5717e-01,  6.5727e+00],\n",
       "         [-7.8421e-01, -6.2049e-01, -7.8319e+00],\n",
       "         [ 2.5166e-01, -9.6782e-01,  1.9979e+00],\n",
       "         [ 8.2750e-02,  9.9657e-01,  4.6085e+00],\n",
       "         [-4.5436e-01,  8.9082e-01,  5.8540e+00],\n",
       "         [-6.9423e-01,  7.1975e-01,  6.1528e+00],\n",
       "         [ 3.1040e-01,  9.5060e-01, -5.7997e+00],\n",
       "         [ 9.8306e-01,  1.8328e-01, -3.7442e+00],\n",
       "         [ 7.1421e-01,  6.9994e-01, -4.5839e+00],\n",
       "         [ 5.2376e-01, -8.5187e-01, -5.3444e+00],\n",
       "         [ 3.0294e-01, -9.5301e-01, -3.2978e-01],\n",
       "         [ 3.0817e-01,  9.5133e-01,  4.0150e+00],\n",
       "         [-1.7401e-01, -9.8474e-01,  4.5875e+00]]),\n",
       " tensor([[3],\n",
       "         [3],\n",
       "         [3],\n",
       "         [3],\n",
       "         [7],\n",
       "         [3],\n",
       "         [3],\n",
       "         [3],\n",
       "         [9],\n",
       "         [3],\n",
       "         [3],\n",
       "         [7],\n",
       "         [3],\n",
       "         [3],\n",
       "         [3],\n",
       "         [3],\n",
       "         [3],\n",
       "         [4],\n",
       "         [3],\n",
       "         [4],\n",
       "         [3],\n",
       "         [3],\n",
       "         [4],\n",
       "         [3],\n",
       "         [3],\n",
       "         [7],\n",
       "         [3],\n",
       "         [4],\n",
       "         [4],\n",
       "         [3],\n",
       "         [3],\n",
       "         [3],\n",
       "         [4],\n",
       "         [4],\n",
       "         [7],\n",
       "         [3],\n",
       "         [4],\n",
       "         [3],\n",
       "         [3],\n",
       "         [4],\n",
       "         [7],\n",
       "         [4],\n",
       "         [3],\n",
       "         [3],\n",
       "         [3],\n",
       "         [3],\n",
       "         [7],\n",
       "         [3],\n",
       "         [7],\n",
       "         [3],\n",
       "         [9],\n",
       "         [3],\n",
       "         [3],\n",
       "         [9],\n",
       "         [3],\n",
       "         [3],\n",
       "         [3],\n",
       "         [4],\n",
       "         [3],\n",
       "         [3],\n",
       "         [3],\n",
       "         [3],\n",
       "         [3],\n",
       "         [3]]),\n",
       " tensor([[ -8.0014],\n",
       "         [ -3.3618],\n",
       "         [ -0.2008],\n",
       "         [ -8.7649],\n",
       "         [ -0.3306],\n",
       "         [ -1.8375],\n",
       "         [ -4.4733],\n",
       "         [-11.9852],\n",
       "         [ -1.7708],\n",
       "         [-15.4665],\n",
       "         [ -2.8239],\n",
       "         [ -0.4276],\n",
       "         [ -0.1669],\n",
       "         [ -2.0575],\n",
       "         [ -1.5565],\n",
       "         [ -2.3686],\n",
       "         [ -6.5968],\n",
       "         [ -2.9007],\n",
       "         [ -2.3412],\n",
       "         [ -0.9397],\n",
       "         [-13.6342],\n",
       "         [ -5.3735],\n",
       "         [ -0.3884],\n",
       "         [-15.9459],\n",
       "         [ -5.3900],\n",
       "         [ -1.3997],\n",
       "         [-14.2051],\n",
       "         [ -3.5292],\n",
       "         [ -4.7714],\n",
       "         [ -1.3961],\n",
       "         [ -6.6946],\n",
       "         [ -1.8820],\n",
       "         [-11.0481],\n",
       "         [ -5.3579],\n",
       "         [ -0.6602],\n",
       "         [ -6.8143],\n",
       "         [ -6.6384],\n",
       "         [ -1.5494],\n",
       "         [ -7.0293],\n",
       "         [ -7.0349],\n",
       "         [ -1.4447],\n",
       "         [ -3.6684],\n",
       "         [ -2.4267],\n",
       "         [ -0.1289],\n",
       "         [ -0.3959],\n",
       "         [ -0.7408],\n",
       "         [ -2.0826],\n",
       "         [-10.3672],\n",
       "         [ -2.1684],\n",
       "         [-14.3004],\n",
       "         [ -1.7151],\n",
       "         [-11.4324],\n",
       "         [-12.2464],\n",
       "         [ -2.1346],\n",
       "         [ -4.3385],\n",
       "         [ -7.5991],\n",
       "         [ -9.2533],\n",
       "         [ -4.9392],\n",
       "         [ -1.4365],\n",
       "         [ -2.7029],\n",
       "         [ -3.8963],\n",
       "         [ -1.6067],\n",
       "         [ -3.1941],\n",
       "         [ -5.1526]]),\n",
       " tensor([[-8.3479e-01, -5.5057e-01, -6.1100e+00],\n",
       "         [ 3.9037e-01, -9.2066e-01, -5.7833e+00],\n",
       "         [ 9.7190e-01, -2.3540e-01, -1.5610e+00],\n",
       "         [-3.9335e-01, -9.1939e-01,  5.3970e+00],\n",
       "         [ 8.2102e-01,  5.7090e-01,  7.6032e-01],\n",
       "         [ 9.6733e-01,  2.5351e-01, -3.8395e+00],\n",
       "         [-4.0030e-01, -9.1638e-01, -4.6006e+00],\n",
       "         [-9.9920e-01, -3.9937e-02, -7.0634e+00],\n",
       "         [ 3.7289e-01, -9.2788e-01,  1.0425e+00],\n",
       "         [-8.6249e-01,  5.0608e-01, -8.0000e+00],\n",
       "         [ 2.5166e-01, -9.6782e-01,  1.9979e+00],\n",
       "         [ 7.8194e-01,  6.2336e-01,  1.3085e+00],\n",
       "         [ 9.7028e-01,  2.4197e-01, -8.2797e-01],\n",
       "         [ 9.3396e-01,  3.5737e-01, -3.8964e+00],\n",
       "         [ 6.0211e-01, -7.9841e-01, -3.8115e+00],\n",
       "         [ 8.9735e-01,  4.4131e-01, -4.0563e+00],\n",
       "         [-2.9719e-01, -9.5482e-01, -7.4619e+00],\n",
       "         [ 5.7571e-01,  8.1765e-01, -3.4169e+00],\n",
       "         [ 9.0119e-01,  4.3342e-01, -4.0446e+00],\n",
       "         [ 7.8355e-01,  6.2133e-01, -1.5086e+00],\n",
       "         [-9.9865e-01, -5.2034e-02, -8.0000e+00],\n",
       "         [-6.2517e-02, -9.9804e-01,  3.6989e+00],\n",
       "         [ 8.4733e-01,  5.3106e-01, -3.7943e-01],\n",
       "         [-9.4008e-01,  3.4096e-01, -8.0000e+00],\n",
       "         [-5.1667e-01, -8.5618e-01, -5.2761e+00],\n",
       "         [ 4.5730e-01,  8.8931e-01,  2.9683e+00],\n",
       "         [-7.3769e-01,  6.7514e-01, -7.8643e+00],\n",
       "         [ 4.3802e-01,  8.9897e-01, -3.6294e+00],\n",
       "         [ 2.6858e-01,  9.6326e-01, -4.2918e+00],\n",
       "         [ 9.7970e-01, -2.0047e-01, -3.8620e+00],\n",
       "         [-3.1337e-01, -9.4963e-01, -7.5024e+00],\n",
       "         [ 1.4843e-01, -9.8892e-01, -2.0113e+00],\n",
       "         [-5.3689e-01,  8.4365e-01, -6.6297e+00],\n",
       "         [ 2.3858e-01,  9.7112e-01, -4.7029e+00],\n",
       "         [ 6.7084e-01,  7.4160e-01,  1.5632e+00],\n",
       "         [-1.7401e-01, -9.8474e-01,  4.5875e+00],\n",
       "         [ 3.0359e-01,  9.5280e-01, -5.8114e+00],\n",
       "         [ 9.9709e-01,  7.6270e-02, -3.7690e+00],\n",
       "         [-2.4498e-01, -9.6953e-01,  4.5460e+00],\n",
       "         [ 5.7525e-03,  9.9998e-01, -5.3929e+00],\n",
       "         [ 4.9168e-01,  8.7077e-01,  3.2419e+00],\n",
       "         [ 7.1421e-01,  6.9994e-01, -4.5839e+00],\n",
       "         [ 5.3321e-03, -9.9999e-01, -2.8730e+00],\n",
       "         [ 9.7884e-01,  2.0462e-01, -7.6649e-01],\n",
       "         [ 9.2719e-01,  3.7459e-01, -1.2046e+00],\n",
       "         [ 8.7052e-01,  4.9213e-01, -1.6560e+00],\n",
       "         [ 2.8326e-01,  9.5904e-01,  3.7553e+00],\n",
       "         [-7.7851e-01, -6.2764e-01, -8.0000e+00],\n",
       "         [ 3.0817e-01,  9.5133e-01,  4.0150e+00],\n",
       "         [-7.4866e-01,  6.6296e-01, -7.8763e+00],\n",
       "         [ 3.1861e-01, -9.4789e-01,  5.0113e-01],\n",
       "         [-9.9090e-01,  1.3462e-01,  6.7955e+00],\n",
       "         [-9.6394e-01, -2.6613e-01, -8.0000e+00],\n",
       "         [ 3.2404e-01, -9.4604e-01,  1.5120e+00],\n",
       "         [-1.7800e-01,  9.8403e-01,  5.2360e+00],\n",
       "         [-7.1159e-01,  7.0260e-01,  6.4021e+00],\n",
       "         [-8.8938e-01,  4.5717e-01,  6.5727e+00],\n",
       "         [ 5.4212e-01,  8.4030e-01, -5.1467e+00],\n",
       "         [ 1.0000e+00, -2.0131e-03, -3.7268e+00],\n",
       "         [ 8.4386e-01,  5.3657e-01, -4.1789e+00],\n",
       "         [ 2.4362e-01, -9.6987e-01, -6.1033e+00],\n",
       "         [ 2.4697e-01, -9.6902e-01, -1.1645e+00],\n",
       "         [ 8.2750e-02,  9.9657e-01,  4.6085e+00],\n",
       "         [ 1.1550e-02, -9.9993e-01,  3.7289e+00]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0]]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sample a mini-batch of transitions from the replay pool `datas`\n",
    "def get_sample(batch_size=64):\n",
    "    # Local import keeps this cell self-contained; numpy is only used for stacking\n",
    "    import numpy as np\n",
    "\n",
    "    # Each pool entry is read as (state, action, reward, next_state, over).\n",
    "    # Cap the request at the pool size: random.sample raises ValueError otherwise.\n",
    "    samples = random.sample(datas, min(batch_size, len(datas)))\n",
    "\n",
    "    # States are stacked into one ndarray first: building a tensor directly from a\n",
    "    # list of ndarrays works but is very slow (PyTorch warns about it).\n",
    "    #[b, 3]\n",
    "    state = torch.as_tensor(np.array([i[0] for i in samples]), dtype=torch.float32).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    action = torch.LongTensor([i[1] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 1]\n",
    "    reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 3]\n",
    "    next_state = torch.as_tensor(np.array([i[3] for i in samples]), dtype=torch.float32).reshape(-1, 3)\n",
    "    #[b, 1] episode-end flags as int64\n",
    "    over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state, action, reward, next_state, over"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.8971],\n",
       "        [1.8895],\n",
       "        [0.4894],\n",
       "        [1.4216],\n",
       "        [0.2038],\n",
       "        [1.4672],\n",
       "        [1.3804],\n",
       "        [2.2540],\n",
       "        [0.4402],\n",
       "        [2.6726],\n",
       "        [0.6657],\n",
       "        [0.3417],\n",
       "        [0.3947],\n",
       "        [1.5066],\n",
       "        [1.2022],\n",
       "        [1.5774],\n",
       "        [2.3895],\n",
       "        [1.5202],\n",
       "        [1.5721],\n",
       "        [0.7820],\n",
       "        [2.7364],\n",
       "        [1.0447],\n",
       "        [0.3909],\n",
       "        [2.6844],\n",
       "        [1.6178],\n",
       "        [0.5999],\n",
       "        [2.6650],\n",
       "        [1.6425],\n",
       "        [1.8870],\n",
       "        [1.3840],\n",
       "        [2.4019],\n",
       "        [0.4456],\n",
       "        [2.4518],\n",
       "        [2.0121],\n",
       "        [0.3733],\n",
       "        [1.2492],\n",
       "        [2.3282],\n",
       "        [1.4077],\n",
       "        [1.2309],\n",
       "        [2.2285],\n",
       "        [0.6557],\n",
       "        [1.8287],\n",
       "        [0.7379],\n",
       "        [0.3664],\n",
       "        [0.5582],\n",
       "        [0.7446],\n",
       "        [0.7122],\n",
       "        [2.7538],\n",
       "        [0.7653],\n",
       "        [2.6653],\n",
       "        [0.3022],\n",
       "        [1.5530],\n",
       "        [2.7063],\n",
       "        [0.5125],\n",
       "        [1.0812],\n",
       "        [1.3949],\n",
       "        [1.4605],\n",
       "        [2.0797],\n",
       "        [1.3770],\n",
       "        [1.6332],\n",
       "        [1.9794],\n",
       "        [0.2887],\n",
       "        [0.9352],\n",
       "        [1.0577]], grad_fn=<GatherBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_value(state, action):\n",
    "    # Q-value the online model assigns to the action that was actually taken.\n",
    "    # Only the state before acting is needed here; reward and next_state are not\n",
    "    # known yet at decision time, so they do not enter this estimate.\n",
    "    #[b, 3] -> [b, 11]: one score per discrete action\n",
    "    all_scores = model(state)\n",
    "\n",
    "    #[b, 11] -> [b, 1]: pick the column indexed by each sample's action\n",
    "    chosen = all_scores.gather(1, action)\n",
    "    return chosen\n",
    "\n",
    "\n",
    "get_value(state, action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-5.9468e+00],\n",
       "        [-1.2753e+00],\n",
       "        [ 3.5574e-01],\n",
       "        [-7.5407e+00],\n",
       "        [ 4.2747e-03],\n",
       "        [-4.5944e-01],\n",
       "        [-2.8450e+00],\n",
       "        [-9.6899e+00],\n",
       "        [-1.4461e+00],\n",
       "        [-1.2837e+01],\n",
       "        [-2.3216e+00],\n",
       "        [ 3.2480e-02],\n",
       "        [ 1.9217e-01],\n",
       "        [-6.6655e-01],\n",
       "        [-1.5564e-01],\n",
       "        [-9.2717e-01],\n",
       "        [-3.9957e+00],\n",
       "        [-1.6398e+00],\n",
       "        [-9.0329e-01],\n",
       "        [-3.2868e-01],\n",
       "        [-1.1004e+01],\n",
       "        [-4.5419e+00],\n",
       "        [-1.1522e-01],\n",
       "        [-1.3334e+01],\n",
       "        [-3.5535e+00],\n",
       "        [-7.0180e-01],\n",
       "        [-1.1569e+01],\n",
       "        [-2.1566e+00],\n",
       "        [-3.1618e+00],\n",
       "        [ 8.9841e-03],\n",
       "        [-4.0822e+00],\n",
       "        [-1.1588e+00],\n",
       "        [-8.7083e+00],\n",
       "        [-3.6242e+00],\n",
       "        [-1.7181e-01],\n",
       "        [-5.7777e+00],\n",
       "        [-4.5954e+00],\n",
       "        [-1.8650e-01],\n",
       "        [-6.0055e+00],\n",
       "        [-5.0630e+00],\n",
       "        [-6.9468e-01],\n",
       "        [-2.0679e+00],\n",
       "        [-1.3744e+00],\n",
       "        [ 2.1083e-01],\n",
       "        [ 8.2662e-02],\n",
       "        [-1.0539e-01],\n",
       "        [-1.2329e+00],\n",
       "        [-7.6562e+00],\n",
       "        [-1.2519e+00],\n",
       "        [-1.1664e+01],\n",
       "        [-1.4882e+00],\n",
       "        [-9.8604e+00],\n",
       "        [-9.5911e+00],\n",
       "        [-1.7032e+00],\n",
       "        [-3.1228e+00],\n",
       "        [-6.1048e+00],\n",
       "        [-7.7314e+00],\n",
       "        [-3.1517e+00],\n",
       "        [-8.3649e-02],\n",
       "        [-1.2255e+00],\n",
       "        [-1.7098e+00],\n",
       "        [-1.1701e+00],\n",
       "        [-2.1345e+00],\n",
       "        [-4.3102e+00]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_target(reward, next_state, over, gamma=0.98):\n",
    "    # TD target for the sampled transitions:\n",
    "    #   target = reward + gamma * max_a next_model(next_state)[a] * (1 - over)\n",
    "    # There is no exact answer for how good a state is, so the estimate is\n",
    "    # bootstrapped from the lagged copy `next_model` (synced periodically in train).\n",
    "    #\n",
    "    # NOTE(review): MyWrapper.step sets done on truncation and after 200 steps. If that\n",
    "    # flag is stored as `over`, time-limit cut-offs are treated as terminal states and\n",
    "    # their bootstrap term is zeroed - confirm how update_data records `over`.\n",
    "\n",
    "    #[b, 3] -> [b, 11]\n",
    "    with torch.no_grad():\n",
    "        next_scores = next_model(next_state)\n",
    "\n",
    "    # Best score reachable from next_state: [b, 11] -> [b, 1]\n",
    "    target = next_scores.max(dim=1, keepdim=True).values\n",
    "\n",
    "    # Discount the future score\n",
    "    target = target * gamma\n",
    "\n",
    "    # If the episode ended, there is no future score to add\n",
    "    #[b, 1] * [b, 1] -> [b, 1]\n",
    "    target = target * (1 - over)\n",
    "\n",
    "    # Add the immediate reward\n",
    "    #[b, 1] + [b, 1] -> [b, 1]\n",
    "    target = target + reward\n",
    "\n",
    "    return target\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1821.5113474840912"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    # Run one full episode with the current policy and return the summed reward\n",
    "    # (larger is better). With play=True, about 20% of frames are drawn.\n",
    "    obs = env.reset()\n",
    "    total = 0\n",
    "\n",
    "    while True:\n",
    "        # Only the continuous form of the chosen action drives the environment\n",
    "        _, action_continuous = get_action(obs)\n",
    "        obs, r, done, _ = env.step([action_continuous])\n",
    "        total += r\n",
    "\n",
    "        # Random frame skipping keeps the animation cheap\n",
    "        if play and random.random() < 0.2:\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "        if done:\n",
    "            break\n",
    "\n",
    "    return total\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 400 200 0 -1551.6725573883625\n",
      "20 4400 200 0 -1248.8182808178515\n",
      "40 5000 200 200 -807.8880145343439\n",
      "60 5000 200 200 -1003.3314484824092\n",
      "80 5000 200 200 -232.94804696095198\n",
      "100 5000 200 200 -168.23735615761436\n",
      "120 5000 200 200 -166.3879498836516\n",
      "140 5000 200 200 -289.496022713074\n",
      "160 5000 200 200 -565.0722293827679\n",
      "180 5000 200 200 -234.31696558612794\n"
     ]
    }
   ],
   "source": [
    "def train(epochs=200, steps_per_epoch=200, sync_every=50, lr=1e-2,\n",
    "          eval_every=20, eval_episodes=20):\n",
    "    # Fit the online `model` with DQN updates; defaults reproduce the original schedule.\n",
    "    #   epochs: outer rounds; each one first refreshes the replay pool via update_data()\n",
    "    #   steps_per_epoch: gradient steps taken after each refresh\n",
    "    #   sync_every: copy the weights of model into next_model every this many steps\n",
    "    #   lr: Adam learning rate\n",
    "    #   eval_every / eval_episodes: every eval_every epochs, print the mean return of\n",
    "    #     eval_episodes test episodes\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        # Refresh the replay pool with new transitions\n",
    "        update_count, drop_count = update_data()\n",
    "\n",
    "        for i in range(steps_per_epoch):\n",
    "            # Draw a mini-batch and regress Q(s, a) towards the TD target\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "            value = get_value(state, action)\n",
    "            target = get_target(reward, next_state, over)\n",
    "\n",
    "            loss = loss_fn(value, target)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # Periodically refresh the lagged network used for targets\n",
    "            if (i + 1) % sync_every == 0:\n",
    "                next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "        if epoch % eval_every == 0:\n",
    "            # Evaluation needs no gradients, so skip building autograd graphs\n",
    "            with torch.no_grad():\n",
    "                returns = [test(play=False) for _ in range(eval_episodes)]\n",
    "            test_result = sum(returns) / eval_episodes\n",
    "            # columns: epoch, pool size, update_count, drop_count, mean test return\n",
    "            print(epoch, len(datas), update_count, drop_count, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAmXklEQVR4nO3df3DU9YH/8dfuZrP5xW4ImIRIIrRYkQFUQGHb79UqOYIXvVpxzlpGKXJ2pMEBuXEqp+JX72bi0VZb7xA716k4d6fc0Tv8QaG9FDToGfkRiSA/cnYOL8GwCT+a3SQmm+zu+/uHsl8XgiZkk31nfT5mdqb5fN772fd+GvbpZ/eTzzqMMUYAAFjImeoJAABwIUQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGCtlEVq/fr1mjRpkrKysjR37lzt2bMnVVMBAFgqJZH613/9V61evVqPPfaY3n33XV111VWqqKhQW1tbKqYDALCUIxUXmJ07d66uvfZa/cM//IMkKRaLqbS0VPfff78eeuihkZ4OAMBSGSP9gL29vaqvr9eaNWviy5xOp8rLy1VXV9fvfcLhsMLhcPznWCymM2fOaNy4cXI4HMM+ZwBAchlj1NHRoZKSEjmdF35Tb8QjderUKUWjURUVFSUsLyoq0tGjR/u9T3V1tR5//PGRmB4AYAQ1Nzdr4sSJF1w/4pG6GGvWrNHq1avjPweDQZWVlam5uVlerzeFMwMAXIxQKKTS0lKNGTPmc8eNeKTGjx8vl8ul1tbWhOWtra0qLi7u9z4ej0cej+e85V6vl0gBwCj2RR/ZjPjZfZmZmZo9e7Z27NgRXxaLxbRjxw75/f6Rng4AwGIpebtv9erVWrJkiebMmaPrrrtOP/vZz9TV1aWlS5emYjoAAEulJFJ33HGHTp48qbVr1yoQCOjqq6/Wb3/72/NOpgAAfLml5O+khioUCsnn8ykYDPKZFACMQgN9HefafQAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsNehI7dq1S7fccotKSkrkcDj08ssvJ6w3xmjt2rWaMGGCsrOzVV5erg8++CBhzJkzZ7R48WJ5vV7l5+dr2bJl6uzsHNITAQCkn0FHqqurS1dddZXWr1/f7/p169bpmWee0XPPPafdu3crNzdXFRUV6unpiY9ZvHixDh06pJqaGm3dulW7du3SD37wg4t/FgCA9GSGQJLZsmVL/OdYLGaKi4vNj3/84/iy9vZ24/F4zEsvvWSMMebw4cNGktm7d298zPbt243D4TAfffTRgB43GAwaSSYYDA5l+gCAFBno63hSP5M6duyYAoGAysvL48t8Pp/mzp2ruro6SVJdXZ3y8/M1Z86c+Jjy8nI5nU7t3r273+2Gw2GFQqGEGwAg/SU1UoFAQJJUVFSUsLyoqCi+LhAIqLCwMGF9RkaGCgoK4mPOVV1dLZ/PF7+VlpYmc9oAAEuNirP71qxZo2AwGL81NzenekoAgBGQ1EgVFxdLklpbWxOWt7a2xtcVFxerra0tYX0kEtGZM2fiY87l8Xjk9XoTbgCA9JfUSE2ePFnFxcXasWNHfFkoFNLu
3bvl9/slSX6/X+3t7aqvr4+P2blzp2KxmObOnZvM6QAARrmMwd6hs7NTf/jDH+I/Hzt2TA0NDSooKFBZWZlWrVqlv/3bv9Xll1+uyZMn69FHH1VJSYluvfVWSdKVV16phQsX6t5779Vzzz2nvr4+rVixQt/97ndVUlKStCcGAEgDgz1t8PXXXzeSzrstWbLEGPPJaeiPPvqoKSoqMh6Px8yfP980NjYmbOP06dPmzjvvNHl5ecbr9ZqlS5eajo6OpJ+6CACw00Bfxx3GGJPCRl6UUCgkn8+nYDDI51MAMAoN9HV8VJzdBwD4ciJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrDfoCswBGlonF1N3UpO5jx9R7+rRMX59cOTnKuvRS5Xzta3JzaTCkMSIFWMrEYuo9dUptr76q4N69ioRCinZ3S7GYHG63XLm5yrr0Uo2vqNBYv19OjyfVUwaSjkgBFjKxmELvvaePNm5U94cfSudcB9r09SnS3q7O9nZ1HjmizqNHVfLd78qdn5+S+QLDhUgBFupqbNTxf/xH9Rw//sWDYzGdrqmRw+HQhO9+V26fb/gnCIwQTpwALNN78qQ++qd/GligPmX6+nSqpkZnamtlotFhnB0wsogUYBETiejktm3qfP/9wd+3t1fHf/Ur9Z4+PQwzA1KDSAEWCbe26o+7d1/8BmIxtW7ZkrwJASlGpACLRDs7FR7E23z96TxyJEmzAVKPSAGWMMYoFomkehqAVYgUYJFoR0eqpwBYhUgBFomEQqmeAmAVIgXYwhgiBZyDSAEWIVJAIiIF2CIW0+mdO1M9C8AqRAqwSKy3N9VTAKxCpIA0kz93bqqnACQNkQLSicOh7EmTUj0LIGmIFGAJE4slZTsZeXlJ2Q5gAyIFWCLa3X3e90ZdDBeRQhohUoAlIh0dMkONlMOhDL5OHmmESAGWiCTpkkjOrKykbAewAZECLBH54x+lJH0uBaQLIgVY4lRNjQxXQQcSECnAFkk4aQJIN0QKAGAtIgWkEU9RkRxO/lkjffDbDKQR7+zZcmRkpHoaQNIQKcACsb4+mWh0yNvJyM2Vw+FIwowAOxApwAKxnh7F+vqGvB1XXp5EpJBGiBRggWhPj0wyIpWbS6SQVogUYIFYd3dSvkuKSCHdECnAAuETJxQJBoe8HYfLxWdSSCtECrBAuKVFkfb2VE8DsA6RAgBYi0gBAKxFpIB0wWdRSENECkgxY8zQv+xQUu7UqcoqLU3CjAB7ECkgxUw0qujHHw95O66sLDkzM5MwI8AeRApItSRFyunxyOl2J2FCgD2IFJBiJhJRtKtryNtxZmXJwZEU0gyRAlIsWW/3cSSFdESkgBTrCwbVceDAkLfjyMiQ+C4ppBl+o4FUi0YV6+5Oyqa4JBLSDZECAFiLSAEArEWkgBRK1h/yAumKSAEpFk3C51HO7GzlfPWrSZgNYBciBaRYJBQa8jacWVnKmjgxCbMB7EKkgBRLypcdOp1y5eQkYTaAXYgUkGKRjo4hb8PhcsmVnZ2E2QB2IVJAioVPnBjyNhwul5wcSSENESkglYxR+zvvDH07TidHUkhLRApIEw4uiYQ0NKjf6urqal177bUaM2aMCgsLdeutt6qxsTFhTE9Pj6qqqjRu3Djl5eVp0aJFam1tTRjT1NSkyspK5eTkqLCwUA8++KAikcjQnw0AIK0MKlK1tbWqqqrSO++8o5qaGvX19WnBggXq+szXDDzwwAN67bXXtHnzZtXW1qqlpUW33XZbfH00GlVlZaV6e3v19ttv64UXXtDGjRu1du3a5D0rAEBacJgh/Ln7yZMnVVhYqNraWn3zm99UMBjUJZdcohdffFG33367JOno0aO68sorVVdXp3nz5mn79u26+eab1dLSoqKiIknSc889px/96Ec6efKkMgfwfTihUEg+n0/BYFBer/dipw+kXCwc1oGlSxXt7BzSdjyXXqrpGzYk
aVbA8Bvo6/iQ3sQOfvr3HQUFBZKk+vp69fX1qby8PD5m6tSpKisrU11dnSSprq5OM2bMiAdKkioqKhQKhXTo0KF+HyccDisUCiXcgHQQ6eyUknBZpMxx45IwG8A+Fx2pWCymVatW6Rvf+IamT58uSQoEAsrMzFR+fn7C2KKiIgUCgfiYzwbq7Pqz6/pTXV0tn88Xv5WWll7stAGrRDo7ZWKxIW8n/+tfT8JsAPtcdKSqqqr0/vvva9OmTcmcT7/WrFmjYDAYvzU3Nw/7YwIjIdrRkZQjqYwxY5IwG8A+GRdzpxUrVmjr1q3atWuXJn7memHFxcXq7e1Ve3t7wtFUa2uriouL42P27NmTsL2zZ/+dHXMuj8cjj8dzMVMFrNbd1CQTjQ55Oxl8Nos0NagjKWOMVqxYoS1btmjnzp2aPHlywvrZs2fL7XZrx44d8WWNjY1qamqS3++XJPn9fh08eFBtbW3xMTU1NfJ6vZo2bdpQngsw6nS8955MX9+Qt+PKy0vCbAD7DOpIqqqqSi+++KJeeeUVjRkzJv4Zks/nU3Z2tnw+n5YtW6bVq1eroKBAXq9X999/v/x+v+bNmydJWrBggaZNm6a77rpL69atUyAQ0COPPKKqqiqOloCLlEGkkKYGFakNn57i+q1vfSth+fPPP6/vf//7kqSnn35aTqdTixYtUjgcVkVFhZ599tn4WJfLpa1bt2r58uXy+/3Kzc3VkiVL9MQTTwztmQBfYk7+Aw9palCRGsifVGVlZWn9+vVav379Bcdcdtll2rZt22AeGgDwJcTFvoAU4WvjgS9GpIAUiYXDivX2pnoagNWIFJAisZ4eGSIFfC4iBaRIso6kfPPmceIE0haRAlIk1tOTlEi58/PlcLmSMCPAPkQKSJHupiaFL3C9ysFw5eZKDkcSZgTYh0gBKRLt7FSsu3vI28nIzeVbeZG2+M0GRjmOpJDOiBQwyjmzs4kU0haRAlIgmX/I63A45CBSSFNECkgFYxQLh1M9C8B6F/V9UgCGxkSjinz88eeOiRmj3mhURp8cLbk+vTk+/Rn4MiBSQAqYWEzRzs7PHXO8q0s/O3xYp8Nhed1uFWZlacbYsboyP1+X5uQoNyODWCHtESkgFaJRRbu6PndIYXa2fjh1qjr7+vTH3l4Furv13pkz+l1Li0pycrSgpET+K66Qu6BghCYNjDwiBaRApKNDwT17PndMlsulKZ/5WviYMeqKRNTa3a1dra169uhRHcjK0oORiPKM4agKaYkTJ4AUMLHYFx5JncvpcGiM262vjhmjpVOm6P9efbWOHz+u+9es0TvvvMNXfyAtESlglDl7yvmkvDw9fPPN+j/XXaeHH35Yb731FqFC2iFSwCjlcDhUWFyslQ88oL/4i7/Q448/rnfffZdQIa0QKWAUc3o8yvF6dffdd+v6669XdXW1Tp06leppAUlDpIARZoyRiUSSsi2H0ymHy6WcnBz95V/+pYwx2rhxI0dTSBtECkiBSCiU9G0WFxdrzZo12rJlixoaGggV0gKRAlJgOCLlcDg0a9Ys+f1+/du//ZvCXHYJaYBIASkwHJGSPgnV7bffrgMHDqi5uZmjKYx6RAoYRt3d3Yr08/lTuKVl6Bt3OOTKyTlnkUNTpkzRmDFjdOTIkaE/BpBiRAoYRr/5zW/03nvvnbf8TG3tBe/zUVeXtjY366X/+R/9vqVFXX19/Y5z5eTIe8015y0fO3asZs6cqbq6OkWj0YufPGABLosEDJNIJKIXX3xRV155pa655ho5v+Ar3o0xOtbZqcf279eHnZ3qiUbldbs1fexY/eTaa+U+9/5O53lHUpLkcrk0ZcoU/frXv1Y0GlVGBv/MMXpxJAUMkwMHDmj//v1644039OGHH37h+P/p7NS9//VfOhIMqvvTr+gI9vXpv9ratHL3bp3u6UkY73A6P/nq+HM4HA5NnDhRgUBAsVgsSc8GSA0iBQyDWCymXbt26eTJkzp69Kj27t37
hcH42aFDCl7grb09p06p5pzPsRwXOJKSpPz8fPX09HDiBEY9IgUMgz/+8Y/at2+fXC6XnE6n3nrrLX38mS85dLjdQ3+Qz4lUZmbmF769CIwGvFkNDIP29nZVVFTI4/GosLBQX/3qV9Xd3a3cT9+em7J2rQ5XVQ3pMRxOp5zZ2f2ui0QivNWHtECkgGFQVlamSy+9VMeOHVN3d7fuuuuu+JGNw+GQo5+TGSpLS7Xv1Cn19fMW3aS8PM0cxJcbhkIhZfDNvUgDvB8ADAO3262srCxddtllampqksvlktvtjkfD7fOp4IYbEu5TUVKix665RlkuV/wfpsvh0DiPRz+99lpNy89PGF+0aFG/j22MUWtrq8aPH89bfhj1OJIChtGsWbP0wgsvqKWlRWVlZfHlzuxsjfX7Fdy3T9GODkmfHGFVlJRoYk6Oth4/rtM9PZqUl6c7Jk/WOI8nYbueCRM01u/v90gpFovpD3/4g77yla8QKYx6RAoYRqWlpcrIyNAHH3yQECmHwyHvNdfokptuUuu//7vMp39063A4NH3sWE0fO/aC28wYO1Yld92ljM98tfxndXZ26tChQ/rWt74ll8uV3CcEjDD+MwsYRtnZ2br66qu1e/du9fb2Jqxzejwq+s53VHDjjf1+RtUfV16eJtxxh/Kvu06OfgJkjFFLS4taWlo0Y8YMPpPCqEekgGGUmZmphQsX6s0339SJEyfO+7slV06OJt5zj4oWLVJmYeEFt+NwuZQ9aZJK771Xl9x0k5yZmf2OM8bo1Vdf1eTJk/WVr3yFSGHUI1LAMHI4HJo5c6YKCgr0yiuv9Ls+IzdXE26/XZNWrtQllZXKnjRJzqysTy4gO2aMcq+8UhO+9z1NeuABFXzzm58bnuPHj2vr1q265ZZblJeXN5xPDRgRfCYFDLNx48bpnnvu0SOPPKIbbrhB06dPPy80To9HedOnK/drX1O0u1smEpGJxeRwueR0u+XMyZHzC94S7Orq0k9/+lNNnz5dN954I0dRSAscSQHDzOFw6IYbbtDChQv1N3/zNzp+/Hi/lytyOBxyejxy5+crc/x4eQoLlTlunDK83i8MVG9vr1566SUdOnRIP/rRj+ROxhUtAAsQKWAEOJ1O3XfffXK73frJT36irq6upF1XzxijXbt26Z//+Z+1fPlylZaWchSFtEGkgBFSWFiotWvX6qOPPtLDDz+sU6dODenSRcYY9fT0aPv27Vq7dq3uvvtu3XLLLZx2jrRCpIAR4nA4dMUVV2jdunU6c+aMVqxYoTfffLPfb+4diEAgoKeeekp/93d/px/+8If6/ve/r8wLnPUHjFYOMwqv5R8KheTz+RQMBuW9wB80ArY6+7dMv/jFL/T73/9eCxcu1NKlS3XppZfG36br7+26s/9Uu7u7VVtbq+eee04Oh0NVVVW64YYb+HJDjCoDfR0nUkAKGGPU19en9957T0899ZSOHTumWbNm6eabb9bkyZOVm5srj8cjh8OhSCSi7u5unT59Wm+//ba2b9+urq4uLV68WN/73vc0ZswYLn+EUYdIAaOAMUbhcFh1dXWqra1VQ0ODurq65Ha74xekjcVi6u3tlTFGZWVl+pM/+RPdeOONCUdewGhDpIBRJhKJ6OTJkzp16pQ6Ojri36zrdruVm5ur/Px8FRYWKi8vjzhh1Bvo6zhvYgOWyMjI0IQJEzRhwoRUTwWwBm9kAwCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWGtQkdqwYYNmzpwpr9crr9crv9+v7du3x9f39PSoqqpK48aNU15enhYtWqTW1taEbTQ1NamyslI5OTkqLCzUgw8+qEgkkpxnAwBIK4OK1MSJE/Xkk0+qvr5e+/bt04033qhvf/vbOnTokCTpgQce0GuvvabNmzertrZWLS0tuu222+L3j0ajqqys
VG9vr95++2298MIL2rhxo9auXZvcZwUASA9miMaOHWt++ctfmvb2duN2u83mzZvj644cOWIkmbq6OmOMMdu2bTNOp9MEAoH4mA0bNhiv12vC4fCAHzMYDBpJJhgMDnX6AIAUGOjr+EV/JhWNRrVp0yZ1dXXJ7/ervr5efX19Ki8vj4+ZOnWqysrKVFdXJ0mqq6vTjBkzVFRUFB9TUVGhUCgUPxrrTzgcVigUSrgBANLfoCN18OBB5eXlyePx6L777tOWLVs0bdo0BQIBZWZmKj8/P2F8UVGRAoGAJCkQCCQE6uz6s+supLq6Wj6fL34rLS0d7LQBAKPQoCN1xRVXqKGhQbt379by5cu1ZMkSHT58eDjmFrdmzRoFg8H4rbm5eVgfDwBgh4zB3iEzM1NTpkyRJM2ePVt79+7Vz3/+c91xxx3q7e1Ve3t7wtFUa2uriouLJUnFxcXas2dPwvbOnv13dkx/PB6PPB7PYKcKABjlhvx3UrFYTOFwWLNnz5bb7daOHTvi6xobG9XU1CS/3y9J8vv9OnjwoNra2uJjampq5PV6NW3atKFOBQCQZgZ1JLVmzRrddNNNKisrU0dHh1588UW98cYb+t3vfiefz6dly5Zp9erVKigokNfr1f333y+/36958+ZJkhYsWKBp06bprrvu0rp16xQIBPTII4+oqqqKIyUAwHkGFam2tjbdfffdOnHihHw+n2bOnKnf/e53+tM//VNJ0tNPPy2n06lFixYpHA6roqJCzz77bPz+LpdLW7du1fLly+X3+5Wbm6slS5boiSeeSO6zAgCkBYcxxqR6EoMVCoXk8/kUDAbl9XpTPR0AwCAN9HWca/cBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsNaQIvXkk0/K4XBo1apV8WU9PT2qqqrSuHHjlJeXp0WLFqm1tTXhfk1NTaqsrFROTo4KCwv14IMPKhKJDGUqAIA0dNGR2rt3r37xi19o5syZCcsfeOABvfbaa9q8ebNqa2vV0tKi2267Lb4+Go2qsrJSvb29evvtt/XCCy9o48aNWrt27cU/CwBAejIXoaOjw1x++eWmpqbGXH/99WblypXGGGPa29uN2+02mzdvjo89cuSIkWTq6uqMMcZs27bNOJ1OEwgE4mM2bNhgvF6vCYfDA3r8YDBoJJlgMHgx0wcApNhAX8cv6kiqqqpKlZWVKi8vT1heX1+vvr6+hOVTp05VWVmZ6urqJEl1dXWaMWOGioqK4mMqKioUCoV06NChfh8vHA4rFAol3AAA6S9jsHfYtGmT3n33Xe3du/e8dYFAQJmZmcrPz09YXlRUpEAgEB/z2UCdXX92XX+qq6v1+OOPD3aqAIBRblBHUs3NzVq5cqX+5V/+RVlZWcM1p/OsWbNGwWAwfmtubh6xxwYApM6gIlVfX6+2tjbNmjVLGRkZysjIUG1trZ555hllZGSoqKhIvb29am9vT7hfa2uriouLJUnFxcXnne139uezY87l8Xjk9XoTbgCA9DeoSM2fP18HDx5UQ0ND/DZnzhwtXrw4/r/dbrd27NgRv09jY6Oamprk9/slSX6/XwcPHlRbW1t8TE1Njbxer6ZNm5akpwUASAeD+kxqzJgxmj59
esKy3NxcjRs3Lr582bJlWr16tQoKCuT1enX//ffL7/dr3rx5kqQFCxZo2rRpuuuuu7Ru3ToFAgE98sgjqqqqksfjSdLTAgCkg0GfOPFFnn76aTmdTi1atEjhcFgVFRV69tln4+tdLpe2bt2q5cuXy+/3Kzc3V0uWLNETTzyR7KkAAEY5hzHGpHoSgxUKheTz+RQMBvl8CgBGoYG+jnPtPgCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtTJSPYGLYYyRJIVCoRTPBABwMc6+fp99Pb+QURmp06dPS5JKS0tTPBMAwFB0dHTI5/NdcP2ojFRBQYEkqamp6XOf3JddKBRSaWmpmpub5fV6Uz0da7GfBob9NDDsp4Exxqijo0MlJSWfO25URsrp/OSjNJ/Pxy/BAHi9XvbTALCfBob9NDDspy82kIMMTpwAAFiLSAEArDUqI+XxePTYY4/J4/GkeipWYz8NDPtpYNhPA8N+Si6H+aLz/wAASJFReSQFAPhyIFIAAGsRKQCAtYgUAMBaozJS69ev16RJk5SVlaW5c+dqz549qZ7SiNq1a5duueUWlZSUyOFw6OWXX05Yb4zR2rVrNWHCBGVnZ6u8vFwffPBBwpgzZ85o8eLF8nq9ys/P17Jly9TZ2TmCz2J4VVdX69prr9WYMWNUWFioW2+9VY2NjQljenp6VFVVpXHjxikvL0+LFi1Sa2trwpimpiZVVlYqJydHhYWFevDBBxWJREbyqQyrDRs2aObMmfE/PPX7/dq+fXt8Pfuof08++aQcDodWrVoVX8a+GiZmlNm0aZPJzMw0v/rVr8yhQ4fMvffea/Lz801ra2uqpzZitm3bZh5++GHzH//xH0aS2bJlS8L6J5980vh8PvPyyy+b9957z/z5n/+5mTx5sunu7o6PWbhwobnqqqvMO++8Y958800zZcoUc+edd47wMxk+FRUV5vnnnzfvv/++aWhoMH/2Z39mysrKTGdnZ3zMfffdZ0pLS82OHTvMvn37zLx588zXv/71+PpIJGKmT59uysvLzf79+822bdvM+PHjzZo1a1LxlIbFq6++an7zm9+Y//7v/zaNjY3mr//6r43b7Tbvv/++MYZ91J89e/aYSZMmmZkzZ5qVK1fGl7Ovhseoi9R1111nqqqq4j9Ho1FTUlJiqqurUzir1Dk3UrFYzBQXF5sf//jH8WXt7e3G4/GYl156yRhjzOHDh40ks3fv3viY7du3G4fDYT766KMRm/tIamtrM5JMbW2tMeaTfeJ2u83mzZvjY44cOWIkmbq6OmPMJ/8x4HQ6TSAQiI/ZsGGD8Xq9JhwOj+wTGEFjx441v/zlL9lH/ejo6DCXX365qampMddff308Uuyr4TOq3u7r7e1VfX29ysvL48ucTqfKy8tVV1eXwpnZ49ixYwoEAgn7yOfzae7cufF9VFdXp/z8fM2ZMyc+pry8XE6nU7t37x7xOY+EYDAo6f9fnLi+vl59fX0J+2nq1KkqKytL2E8zZsxQUVFRfExFRYVCoZAOHTo0grMfGdFoVJs2bVJXV5f8fj/7qB9VVVWqrKxM2CcSv0/DaVRdYPbUqVOKRqMJ/ydLUlFRkY4ePZqiWdkl
EAhIUr/76Oy6QCCgwsLChPUZGRkqKCiIj0knsVhMq1at0je+8Q1Nnz5d0if7IDMzU/n5+Qljz91P/e3Hs+vSxcGDB+X3+9XT06O8vDxt2bJF06ZNU0NDA/voMzZt2qR3331Xe/fuPW8dv0/DZ1RFCrgYVVVVev/99/XWW2+leipWuuKKK9TQ0KBgMKhf//rXWrJkiWpra1M9Las0Nzdr5cqVqqmpUVZWVqqn86Uyqt7uGz9+vFwu13lnzLS2tqq4uDhFs7LL2f3wefuouLhYbW1tCesjkYjOnDmTdvtxxYoV2rp1q15//XVNnDgxvry4uFi9vb1qb29PGH/ufupvP55dly4yMzM1ZcoUzZ49W9XV1brqqqv085//nH30GfX19Wpra9OsWbOUkZGhjIwM1dbW6plnnlFGRoaKiorYV8NkVEUqMzNTs2fP1o4dO+LLYrGYduzYIb/fn8KZ2WPy5MkqLi5O2EehUEi7d++O7yO/36/29nbV19fHx+zcuVOxWExz584d8TkPB2OMVqxYoS1btmjnzp2aPHlywvrZs2fL7XYn7KfGxkY1NTUl7KeDBw8mBL2mpkZer1fTpk0bmSeSArFYTOFwmH30GfPnz9fBgwfV0NAQv82ZM0eLFy+O/2/21TBJ9Zkbg7Vp0ybj8XjMxo0bzeHDh80PfvADk5+fn3DGTLrr6Ogw+/fvN/v37zeSzFNPPWX2799v/vd//9cY88kp6Pn5+eaVV14xBw4cMN/+9rf7PQX9mmuuMbt37zZvvfWWufzyy9PqFPTly5cbn89n3njjDXPixIn47eOPP46Pue+++0xZWZnZuXOn2bdvn/H7/cbv98fXnz1leMGCBaahocH89re/NZdccklanTL80EMPmdraWnPs2DFz4MAB89BDDxmHw2H+8z//0xjDPvo8nz27zxj21XAZdZEyxpi///u/N2VlZSYzM9Ncd9115p133kn1lEbU66+/biSdd1uyZIkx5pPT0B999FFTVFRkPB6PmT9/vmlsbEzYxunTp82dd95p8vLyjNfrNUuXLjUdHR0peDbDo7/9I8k8//zz8THd3d3mhz/8oRk7dqzJyckx3/nOd8yJEycStvPhhx+am266yWRnZ5vx48ebv/qrvzJ9fX0j/GyGzz333GMuu+wyk5mZaS655BIzf/78eKCMYR99nnMjxb4aHnxVBwDAWqPqMykAwJcLkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANb6f6zbQYAfDG6OAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "-126.85454989590353"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Watch one rendered episode with the trained policy; the cell shows its total reward\n",
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
